{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e446da2d668c2e07",
   "metadata": {},
   "source": [
    "# LaB-GATr Basic Usage\n",
    "\n",
    "Before using this notebook, install the correct dependencies and LaB-GATr as follows:\n",
    "\n",
    "Optional new Anaconda environment\n",
    "```\n",
    "conda create --name lab-gatr python=3.10\n",
    "conda activate lab-gatr\n",
    "``` \n",
    "Next, install PyTorch and xFormers and other libraries\n",
    "```\n",
    "pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cu121\n",
    "pip install xformers==0.0.22.post7 --index-url https://download.pytorch.org/whl/cu121\n",
    "pip install torch_geometric==2.4.0\n",
    "pip install torch_scatter torch_cluster -f https://data.pyg.org/whl/torch-2.1.0+cu121.html\n",
    "```\n",
    "Install LaB-GATr itself, which also install GATr\n",
    "```\n",
    "pip install .\n",
    "```\n",
    "Additionally, if you have made a new Anaconda environment, install Jupyter\n",
    "```\n",
    "pip install jupyter jupyterlab\n",
    "```\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {},
   "source": [
    "import torch\n",
    "from lab_gatr import PointCloudPoolingScales, LaBGATr\n",
    "import torch_geometric as pyg\n",
    "from gatr.interface import embed_oriented_plane, extract_translation"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "89de2d017dd9efa3",
   "metadata": {},
   "source": [
    "Let us first create a dummy mesh: n positions and orientations (e.g. surface normal) and an arbitrary scalar feature (e.g. geodesic distance)."
   ]
  },
  {
   "cell_type": "code",
   "id": "7264db0cc4bfa961",
   "metadata": {},
   "source": [
    "n = 1000\n",
    "\n",
    "pos, orientation = torch.rand((n, 3)), torch.rand((n, 3))\n",
    "scalar_feature = torch.rand(n)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "cb6a1fa784bd6801",
   "metadata": {},
   "source": [
    "Next, a point cloud pooling transform for the tokenisation (patching). Also apply this as a transform to create the dummy data."
   ]
  },
  {
   "cell_type": "code",
   "id": "ac963dbde04fc5c5",
   "metadata": {},
   "source": [
    "transform = PointCloudPoolingScales(rel_sampling_ratios=(0.2,), interp_simplex='triangle')\n",
    "dummy_data = transform(pyg.data.Data(pos=pos, orientation=orientation, scalar_feature=scalar_feature))"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "546e0c748430850a",
   "metadata": {},
   "source": [
    "A geometric algebra interface to embed your data in $\\mathbf{G}(3, 0, 1)$."
   ]
  },
  {
   "cell_type": "code",
   "id": "375b5f1d6f0c856b",
   "metadata": {},
   "source": [
    "class GeometricAlgebraInterface:\n",
    "    num_input_channels = num_output_channels = 1\n",
    "    num_input_scalars = num_output_scalars = 1\n",
    "\n",
    "    @staticmethod\n",
    "    @torch.no_grad()\n",
    "    def embed(data):\n",
    "\n",
    "        multivectors = embed_oriented_plane(normal=data.orientation, position=data.pos).view(-1, 1, 16)\n",
    "        scalars = data.scalar_feature.view(-1, 1)\n",
    "\n",
    "        return multivectors, scalars\n",
    "\n",
    "    @staticmethod\n",
    "    def dislodge(multivectors, scalars):\n",
    "        return extract_translation(multivectors).squeeze()\n"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "c832f5754fbd5911",
   "metadata": {},
   "source": [
    "Create a model instance with dimensionality 8, 10 GATr blocks, and 4 attention heads in each block. It uses the geometric algebra interface we defined above."
   ]
  },
  {
   "cell_type": "code",
   "id": "63280835758baa8f",
   "metadata": {},
   "source": [
    "model = LaBGATr(GeometricAlgebraInterface, d_model=8, num_blocks=10, num_attn_heads=4, use_class_token=False)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "38764051e0c48259",
   "metadata": {},
   "source": [
    "Generate some output with the dummy data to verify that the model functions. Training or inference from here on is the same as any PyTorch model."
   ]
  },
  {
   "cell_type": "code",
   "id": "4dfd82278687c9c4",
   "metadata": {},
   "source": [
    "output = model(dummy_data)\n",
    "print(output.shape)"
   ],
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
